{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Namespace(\n",
    "    raw_dataset_csv=\"data/ag_news/news.csv\",\n",
    "    train_proportion=0.7,\n",
    "    val_proportion=0.15,\n",
    "    test_proportion=0.15,\n",
    "    output_munged_csv=\"data/ag_news/news_with_splits.csv\",\n",
    "    seed=1337\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read raw data\n",
    "news = pd.read_csv(args.raw_dataset_csv, header=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>category</th>\n",
       "      <th>title</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Business</td>\n",
       "      <td>Wall St. Bears Claw Back Into the Black (Reuters)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Business</td>\n",
       "      <td>Carlyle Looks Toward Commercial Aerospace (Reu...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Business</td>\n",
       "      <td>Oil and Economy Cloud Stocks' Outlook (Reuters)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Business</td>\n",
       "      <td>Iraq Halts Oil Exports from Main Southern Pipe...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Business</td>\n",
       "      <td>Oil prices soar to all-time record, posing new...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   category                                              title\n",
       "0  Business  Wall St. Bears Claw Back Into the Black (Reuters)\n",
       "1  Business  Carlyle Looks Toward Commercial Aerospace (Reu...\n",
       "2  Business    Oil and Economy Cloud Stocks' Outlook (Reuters)\n",
       "3  Business  Iraq Halts Oil Exports from Main Southern Pipe...\n",
       "4  Business  Oil prices soar to all-time record, posing new..."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "news.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Business', 'Sci/Tech', 'Sports', 'World'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unique classes\n",
    "set(news.category)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splitting train by category\n",
    "# Create dict\n",
    "by_category = collections.defaultdict(list)\n",
    "for _, row in news.iterrows():\n",
    "    by_category[row.category].append(row.to_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create split data\n",
    "final_list = []\n",
    "np.random.seed(args.seed)\n",
    "for _, item_list in sorted(by_category.items()):\n",
    "    np.random.shuffle(item_list)\n",
    "    n = len(item_list)\n",
    "    n_train = int(args.train_proportion*n)\n",
    "    n_val = int(args.val_proportion*n)\n",
    "    n_test = int(args.test_proportion*n)\n",
    "    \n",
    "    # Give data point a split attribute\n",
    "    for item in item_list[:n_train]:\n",
    "        item['split'] = 'train'\n",
    "    for item in item_list[n_train:n_train+n_val]:\n",
    "        item['split'] = 'val'\n",
    "    for item in item_list[n_train+n_val:]:\n",
    "        item['split'] = 'test'  \n",
    "    \n",
    "    # Add to final list\n",
    "    final_list.extend(item_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write split data to file\n",
    "final_news = pd.DataFrame(final_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "train    84000\n",
       "test     18000\n",
       "val      18000\n",
       "Name: split, dtype: int64"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_news.split.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the reviews\n",
    "def preprocess_text(text):\n",
    "    text = ' '.join(word.lower() for word in text.split(\" \"))\n",
    "    text = re.sub(r\"([.,!?])\", r\" \\1 \", text)\n",
    "    text = re.sub(r\"[^a-zA-Z.,!?]+\", r\" \", text)\n",
    "    return text\n",
    "    \n",
    "final_news.title = final_news.title.apply(preprocess_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>category</th>\n",
       "      <th>split</th>\n",
       "      <th>title</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Business</td>\n",
       "      <td>train</td>\n",
       "      <td>jobs , tax cuts key issues for bush</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Business</td>\n",
       "      <td>train</td>\n",
       "      <td>jarden buying mr . coffee s maker</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Business</td>\n",
       "      <td>train</td>\n",
       "      <td>retail sales show festive fervour</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Business</td>\n",
       "      <td>train</td>\n",
       "      <td>intervoice s customers come calling</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Business</td>\n",
       "      <td>train</td>\n",
       "      <td>boeing expects air force contract</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   category  split                                title\n",
       "0  Business  train  jobs , tax cuts key issues for bush\n",
       "1  Business  train    jarden buying mr . coffee s maker\n",
       "2  Business  train    retail sales show festive fervour\n",
       "3  Business  train  intervoice s customers come calling\n",
       "4  Business  train    boeing expects air force contract"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_news.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write munged data to CSV\n",
    "final_news.to_csv(args.output_munged_csv, index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlpbook",
   "language": "python",
   "name": "nlpbook"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "12px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
